{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Dataset Adapter\n",
    "\n",
    "With diverse dataset structures available, ensuring compatibility with SuperGradients (SG) can be challenging. This is where the DataloaderAdapter plays a pivotal role. This tutorial takes you through the importance, implementation, and advantages of using the DataloaderAdapter with SG."
   ],
   "metadata": {
    "id": "maykjDsh7d2x",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Why Dataset Adapter?\n",
    "\n",
    "Datasets come in a myriad of structures. However, SG requires data in a specific format.\n",
    "\n",
    "For instance, consider the Object Detection Format:\n",
    "\n",
     "- Image format: (BS, H, W, C), i.e., channel last.\n",
     "- Targets format: (N, 6), where the 6 values represent (sample_id, class_id, cx, cy, w, h).\n",
    "The overhead of adjusting each dataset manually can be cumbersome. Enter DataloaderAdapter – designed to automatically understand your dataset structure and mold it for SG compatibility."
   ],
   "metadata": {
    "id": "rYLVw---7mgu",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
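   {
    "cell_type": "markdown",
    "source": [
     "To make the expected detection target layout concrete, here is a toy example (the values below are made up purely for illustration):\n",
     "\n",
     "```python\n",
     "import torch\n",
     "\n",
     "# Each row describes one bounding box: (sample_id, class_id, cx, cy, w, h)\n",
     "targets = torch.tensor([\n",
     "    [0, 17, 0.50, 0.40, 0.20, 0.30],  # box of class 17 in the 1st sample of the batch\n",
     "    [0,  2, 0.25, 0.60, 0.10, 0.15],  # box of class 2 in the same sample\n",
     "    [1, 17, 0.70, 0.55, 0.30, 0.25],  # box of class 17 in the 2nd sample\n",
     "])\n",
     "print(targets.shape)  # torch.Size([3, 6])\n",
     "```"
    ],
    "metadata": {}
   },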
  {
   "cell_type": "code",
   "source": [
    "!pip install -q super-gradients==3.7.1"
   ],
   "metadata": {
    "id": "0puCRQGZSP8r",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "f244cb86-c7e5-419b-f0e1-f0807aac017d",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 8,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "  Preparing metadata (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n",
      "  Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n",
      "  Preparing metadata (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n",
      "  Preparing metadata (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Preparing metadata (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Preparing metadata (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m458.9/458.9 kB\u001B[0m \u001B[31m5.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m46.0/46.0 kB\u001B[0m \u001B[31m5.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m11.3/11.3 MB\u001B[0m \u001B[31m34.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m79.8/79.8 kB\u001B[0m \u001B[31m10.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m108.3/108.3 kB\u001B[0m \u001B[31m12.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25h  Preparing metadata (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m176.0/176.0 kB\u001B[0m \u001B[31m19.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m407.7/407.7 kB\u001B[0m \u001B[31m36.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m107.7/107.7 kB\u001B[0m \u001B[31m16.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m277.4/277.4 kB\u001B[0m \u001B[31m31.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m2.8/2.8 MB\u001B[0m \u001B[31m60.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m913.9/913.9 kB\u001B[0m \u001B[31m55.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25h  Preparing metadata (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m117.0/117.0 kB\u001B[0m \u001B[31m17.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25h  Preparing metadata (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m575.5/575.5 kB\u001B[0m \u001B[31m57.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m121.1/121.1 kB\u001B[0m \u001B[31m16.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m86.8/86.8 kB\u001B[0m \u001B[31m12.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m120.0/120.0 kB\u001B[0m \u001B[31m15.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m120.0/120.0 kB\u001B[0m \u001B[31m16.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m120.6/120.6 kB\u001B[0m \u001B[31m16.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m83.5/83.5 kB\u001B[0m \u001B[31m10.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m83.5/83.5 kB\u001B[0m \u001B[31m12.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m84.7/84.7 kB\u001B[0m \u001B[31m12.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m99.2/99.2 kB\u001B[0m \u001B[31m13.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m99.2/99.2 kB\u001B[0m \u001B[31m13.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m99.8/99.8 kB\u001B[0m \u001B[31m12.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m89.4/89.4 kB\u001B[0m \u001B[31m12.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m89.4/89.4 kB\u001B[0m \u001B[31m12.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m90.6/90.6 kB\u001B[0m \u001B[31m12.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m92.6/92.6 kB\u001B[0m \u001B[31m12.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m92.6/92.6 kB\u001B[0m \u001B[31m12.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m92.6/92.6 kB\u001B[0m \u001B[31m13.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m94.0/94.0 kB\u001B[0m \u001B[31m12.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m105.0/105.0 kB\u001B[0m \u001B[31m14.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m46.2/46.2 kB\u001B[0m \u001B[31m5.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m106.8/106.8 kB\u001B[0m \u001B[31m14.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m194.6/194.6 kB\u001B[0m \u001B[31m21.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m58.1/58.1 kB\u001B[0m \u001B[31m8.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25h  Building wheel for pycocotools (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n",
      "  Building wheel for termcolor (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Building wheel for treelib (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Building wheel for coverage (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Building wheel for xhtml2pdf (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Building wheel for antlr4-python3-runtime (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Building wheel for stringcase (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Building wheel for svglib (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "lida 0.0.10 requires fastapi, which is not installed.\n",
      "lida 0.0.10 requires kaleido, which is not installed.\n",
      "lida 0.0.10 requires python-multipart, which is not installed.\n",
      "lida 0.0.10 requires uvicorn, which is not installed.\n",
      "tensorflow 2.14.0 requires numpy>=1.23.5, but you have numpy 1.23.0 which is incompatible.\u001B[0m\u001B[31m\n",
      "\u001B[0m"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Why Do We Need the Dataset Adapter?\n",
    "\n",
     "Datasets come in various structures and formats, but SG expects data in a specific format in order to run.\n",
    "\n",
    "\n",
    "> Example: Object Detection Format\n",
    "> - Image format: (BS, H, W, C) i.e. channel last\n",
     "> - Targets format: (N, 6), where the 6 values represent (sample_id, class_id, cx, cy, w, h).\n",
    "\n",
    "\n",
     "This means that you should either use one of SuperGradients' built-in Dataset classes if one supports your dataset structure, or, if your dataset is too custom for them, inherit from an SG dataset and make all the required changes.\n",
    "\n",
    "While this is all right in most cases, it can be cumbersome when you just want to quickly experiment with a new dataset.\n",
    "\n",
    "To reduce this overhead, SuperGradients introduced the concept of `DataloaderAdapter`. Instead of requiring you to write all the transformations required to use SG, the `DataloaderAdapter` will infer anything possible directly from your data. Whenever something cannot be inferred with 100% confidence, you will be asked a question with all the required context for you to properly answer.\n",
    "\n",
     "Let's see this in practice, starting with the `SBDataset` dataset."
   ],
   "metadata": {
    "id": "VWKtR3sOfRuB",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
     "# Example 1 - Segmentation Adapter on `SBDataset` Dataset\n",
    "\n",
    "In this section, we'll walk through the process of preparing the `SBDataset` dataset for use in SuperGradients. We'll highlight the challenges and demonstrate how the Adapter can simplify the process.\n",
    "\n",
    "\n",
    "1. Preparing the Dataset without Adapter"
   ],
   "metadata": {
    "id": "dvbJpo5Z7w6n",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "from torchvision.datasets import SBDataset\n",
    "\n",
    "try:\n",
    "  # There is a bug with `torchvision.datasets.SBDataset` that raises RuntimeError after downloading, so we just ignore it\n",
    "  SBDataset(root=\"data\", mode='segmentation', download=True)\n",
    "except RuntimeError:\n",
    "  pass"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BHg2-CiFTcx9",
    "outputId": "4043dee7-29c5-40aa-d2ad-0a10fc4fd1c6",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 3,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Downloading https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz to data/benchmark.tgz\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "100%|██████████| 1419539633/1419539633 [00:17<00:00, 79796528.12it/s] \n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Extracting data/benchmark.tgz to data\n",
      "Downloading https://www.cs.cornell.edu/~bharathh/ to data/train_noval.txt\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "21334it [00:00, 299478.84it/s]\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "oGHI8LgZSIiz",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from torchvision.transforms import Compose, ToTensor, Resize, InterpolationMode\n",
    "\n",
    "\n",
    "transforms = Compose([ToTensor(), Resize((512, 512), InterpolationMode.NEAREST)])\n",
    "def sample_transform(image, mask):\n",
    "  return transforms(image), transforms(mask)\n",
    "\n",
    "train_set = SBDataset(root=\"data\", mode='segmentation', download=False, transforms=sample_transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Now let's see what we get when instantiating a `Dataloader`"
   ],
   "metadata": {
    "id": "SEuJd57v8ELj",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "train_loader = DataLoader(train_set, batch_size=20, shuffle=True)\n",
    "_images, labels = next(iter(train_loader))\n",
    "\n",
    "labels.unique()"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "AnJJNCSr8DUW",
    "outputId": "1f4fb2a5-fbe1-4604-ce6c-ff85079c9180",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 5,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([0.0000, 0.0157, 0.0275, 0.0314, 0.0353, 0.0431, 0.0471, 0.0510, 0.0549,\n",
       "        0.0588, 0.0627, 0.0667, 0.0745])"
      ]
     },
     "metadata": {},
     "execution_count": 5
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
     "As you can see, the labels are normalized to the [0, 1] range. This is fine for some use cases, but it is not the format expected by SuperGradients.\n",
     "\n",
     "Let's now see how the Adapter helps.\n",
     "\n",
     "2. Introducing Adapter\n",
     "\n",
     "The Adapter lets us skip manual data preparation and dive straight into creating a dataloader in the format SuperGradients expects."
   ],
   "metadata": {
    "id": "Ex4-AmV474Nl",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
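   {
    "cell_type": "markdown",
    "source": [
     "For reference, this is roughly the manual fix the Adapter spares us here: `ToTensor()` rescaled the `uint8` mask by dividing it by 255, so the integer class ids could be recovered by undoing that scaling (a minimal sketch of the idea, not what the Adapter actually runs):\n",
     "\n",
     "```python\n",
     "import torch\n",
     "\n",
     "# A few of the normalized mask values observed above\n",
     "normalized_mask = torch.tensor([0.0000, 0.0157, 0.0275, 0.0745])\n",
     "\n",
     "# Undo the ToTensor() scaling to get back integer class ids\n",
     "class_ids = (normalized_mask * 255).round().long()\n",
     "print(class_ids)  # tensor([ 0,  4,  7, 19])\n",
     "```"
    ],
    "metadata": {}
   },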
  {
   "cell_type": "code",
   "source": [
    "from super_gradients.training.dataloaders.adapters import SegmentationDataloaderAdapterFactory\n",
    "\n",
    "train_loader = SegmentationDataloaderAdapterFactory.from_dataset(dataset=train_set, batch_size=20, shuffle=True, config_path='cache_file.json')\n",
    "\n",
    "_images, labels = next(iter(train_loader))\n",
    "labels.unique()"
   ],
   "metadata": {
    "id": "BVDBTQd_FZMe",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "9ac4683d-25ed-427a-cba0-e132f10bea77",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 11,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "[2023-11-13 13:56:24] INFO - data_config.py - Cache deactivated for `SegmentationDataConfig`.\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001B[33;1mHow many classes does your dataset include?\u001B[0m\n",
      "--------------------------------------------------------------------------------\n",
      "\n",
      "Enter your response >>> 21\n",
      "Great! \u001B[33;1mYou chose: `21`\u001B[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001B[33;1mDoes your dataset provide a batch or a single sample?\u001B[0m\n",
      "--------------------------------------------------------------------------------\n",
      "Image shape: torch.Size([3, 512, 512])\n",
      "Mask shape: torch.Size([1, 512, 512])\n",
      "\u001B[34;1mOptions\u001B[0m:\n",
      "[\u001B[34;1m0\u001B[0m] | Batch of Samples (e.g. torch Dataloader)\n",
      "[\u001B[34;1m1\u001B[0m] | Single Sample (e.g. torch Dataset)\n",
      "\n",
      "Your selection (Enter the \u001B[34;1mcorresponding number\u001B[0m) >>> 1\n",
      "Great! \u001B[33;1mYou chose: `Single Sample (e.g. torch Dataset)`\u001B[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001B[33;1mIn which format are your images loaded ?\u001B[0m\n",
      "--------------------------------------------------------------------------------\n",
      "\n",
      "\u001B[34;1mOptions\u001B[0m:\n",
      "[\u001B[34;1m0\u001B[0m] | RGB\n",
      "[\u001B[34;1m1\u001B[0m] | BGR\n",
      "[\u001B[34;1m2\u001B[0m] | LAB\n",
      "[\u001B[34;1m3\u001B[0m] | Other\n",
      "\n",
      "Your selection (Enter the \u001B[34;1mcorresponding number\u001B[0m) >>> 0\n",
      "Great! \u001B[33;1mYou chose: `RGB`\u001B[0m\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([  0,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,\n",
       "         15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  28,  29,\n",
       "         30,  31,  32,  33,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,\n",
       "         45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,\n",
       "         59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,\n",
       "         73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,\n",
       "         87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100,\n",
       "        101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114,\n",
       "        115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128,\n",
       "        129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,\n",
       "        143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156,\n",
       "        157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170,\n",
       "        171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184,\n",
       "        185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198,\n",
       "        199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212,\n",
       "        213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226,\n",
       "        227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240,\n",
       "        241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254,\n",
       "        255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268,\n",
       "        269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282,\n",
       "        283, 284, 285, 286, 287, 288, 289, 290, 291, 293, 294, 295, 296, 297,\n",
       "        298, 299, 300, 301, 302, 303, 304, 305, 306, 308, 309, 310, 311, 312,\n",
       "        313, 314, 315, 316, 317, 319, 320, 321, 322, 323, 324, 325, 326, 327,\n",
       "        328, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 342, 343,\n",
       "        345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,\n",
       "        359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372,\n",
       "        373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386,\n",
       "        387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 398, 399, 400, 401,\n",
       "        402, 405, 406, 407, 409, 410, 411, 412, 413, 415, 416, 417, 418, 419,\n",
       "        420, 421, 423, 424, 425, 428, 429, 431, 435, 436, 437, 439, 440, 442,\n",
       "        443, 444, 446, 447, 448, 450, 451, 452, 454, 455, 457, 458, 459, 461,\n",
       "        462, 463, 465, 466, 467, 470, 474, 478, 481, 485, 489, 493])"
      ]
     },
     "metadata": {},
     "execution_count": 11
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "You can see that the mask is now encoded as `int`, which is the representation used in SuperGradients.\n",
    "\n",
     "It's important to note that the dataset adapter also supports other dataset formats, such as one-hot encoded masks, ensuring that the output (`labels` here) is in the right format for use within SuperGradients."
   ],
   "metadata": {
    "id": "yMnGHW1kCO2h",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
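   {
    "cell_type": "markdown",
    "source": [
     "For instance, a one-hot encoded mask would conceptually be collapsed into integer class ids along the class channel. Here is a minimal sketch of that idea (not the Adapter's actual implementation):\n",
     "\n",
     "```python\n",
     "import torch\n",
     "\n",
     "# A one-hot mask of shape (num_classes, H, W): here 3 classes on a 2x2 image\n",
     "one_hot = torch.zeros(3, 2, 2)\n",
     "one_hot[0, 0, :] = 1  # the first row of pixels belongs to class 0\n",
     "one_hot[2, 1, :] = 1  # the second row of pixels belongs to class 2\n",
     "\n",
     "# Collapse the class channel into integer class ids, shape (H, W)\n",
     "int_mask = one_hot.argmax(dim=0)\n",
     "print(int_mask.tolist())  # [[0, 0], [2, 2]]\n",
     "```"
    ],
    "metadata": {}
   },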
  {
   "cell_type": "markdown",
   "source": [
     "## Example 2 - Detection Adapter on a Dictionary-Based Dataset\n",
    "\n",
    "Some datasets return a more complex data structure than the previous example.\n",
    "\n",
     "For instance, the `CocoDetection` dataset implementation from `torchvision` returns a list of dictionaries representing the labels.\n",
    "\n",
    "Let's have a look:\n"
   ],
   "metadata": {
    "id": "-sEjb4d6jqIK",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Download the zip file\n",
    "!wget https://deci-pretrained-models.s3.amazonaws.com/coco2017_small.zip\n",
    "\n",
    "# Unzip the downloaded file\n",
    "!unzip coco2017_small.zip > /dev/null"
   ],
   "metadata": {
    "id": "7nb2DFNhbHff",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "aaf48023-dc04-4f72-9318-ffa15f5eb6b2",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 12,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "--2023-11-13 13:56:46--  https://deci-pretrained-models.s3.amazonaws.com/coco2017_small.zip\n",
      "Resolving deci-pretrained-models.s3.amazonaws.com (deci-pretrained-models.s3.amazonaws.com)... 52.216.211.169, 52.216.246.28, 3.5.2.158, ...\n",
      "Connecting to deci-pretrained-models.s3.amazonaws.com (deci-pretrained-models.s3.amazonaws.com)|52.216.211.169|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 246116231 (235M) [application/zip]\n",
      "Saving to: ‘coco2017_small.zip’\n",
      "\n",
      "coco2017_small.zip  100%[===================>] 234.71M  39.5MB/s    in 6.3s    \n",
      "\n",
      "2023-11-13 13:56:53 (37.0 MB/s) - ‘coco2017_small.zip’ saved [246116231/246116231]\n",
      "\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "from torchvision.datasets import CocoDetection\n",
    "from torchvision.transforms import Compose, ToTensor, Resize, InterpolationMode\n",
    "\n",
    "\n",
    "image_transform = Compose([ToTensor(), Resize((512, 512))])\n",
    "\n",
    "train_set = CocoDetection(root='coco2017_small/images/train2017', annFile='coco2017_small/annotations/instances_train2017.json', transform=image_transform)\n",
    "val_set = CocoDetection(root='coco2017_small/images/val2017', annFile='coco2017_small/annotations/instances_val2017.json', transform=image_transform)\n",
    "image, targets = next(iter(train_set))"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "k8ox8EbVbxVU",
    "outputId": "72cf85a5-e143-436d-9d0e-12ce2e8d8742",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 13,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "loading annotations into memory...\n",
      "Done (t=0.13s)\n",
      "creating index...\n",
      "index created!\n",
      "loading annotations into memory...\n",
      "Done (t=0.06s)\n",
      "creating index...\n",
      "index created!\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "print(f\"Number of targets: {len(targets)}, First target structure: {targets[0]}\")"
   ],
   "metadata": {
    "id": "BhuJfMHM9g-a",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 14,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
     "Notice the nested dictionary structure of the dataset output. This complicates things for the Dataset Adapter, since it cannot tell on its own which fields describe the bounding boxes.\n",
    "\n",
    "To solve this, we utilize an extractor function.\n",
    "\n",
    "#### The Extractor's Role\n",
    "\n",
    "Simply put, the extractor translates your dataset's output into a format the Adapter understands. For our dataset, it will take the image and annotations, then return the bounding box data, including the label and coordinates.\n",
    "\n",
     "Worried about bounding box formats like `xyxy_label` or `label_xywh`? Don't be. The Adapter is designed to recognize them.\n",
    "\n",
    "> For further guidance on extractor functions, see the [official documentation](https://github.com/Deci-AI/data-gradients/blob/master/documentation/dataset_extractors.md)."
   ],
   "metadata": {
    "id": "VKWqK2OwdbS9",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "import torch\n",
    "\n",
    "def coco_labels_extractor(sample) -> torch.Tensor:\n",
    "    _, annotations = sample # annotations = [{\"bbox\": [1.08, 187.69, 611.59, 285.84], \"category_id\": 51}, ...]\n",
    "    labels = []\n",
    "    for annotation in annotations:\n",
    "        class_id = annotation[\"category_id\"]\n",
    "        bbox = annotation[\"bbox\"]\n",
    "        labels.append((class_id, *bbox))\n",
    "    return torch.Tensor(labels) # torch.Tensor([[51, 1.08, 187.69, 611.59, 285.84], ...])\n",
    "\n",
    "coco_labels_extractor(sample=next(iter(train_set)))"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JkmopVSocq9e",
    "outputId": "e51572d9-e066-4639-e02b-3cc5ea5da167",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 15,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[ 51.0000,   1.0800, 187.6900, 611.5900, 285.8400],\n",
       "        [ 51.0000, 311.7300,   4.3100, 319.2800, 228.6800],\n",
       "        [ 56.0000, 249.6000, 229.2700, 316.2400, 245.0800],\n",
       "        [ 51.0000,   0.0000,  13.5100, 434.4800, 375.1200],\n",
       "        [ 55.0000, 376.2000,  40.3600,  75.5500,  46.5300],\n",
       "        [ 55.0000, 465.7800,  38.9700,  58.0700,  46.6700],\n",
       "        [ 55.0000, 385.7000,  73.6600,  84.0200,  70.5100],\n",
       "        [ 55.0000, 364.0500,   2.4900,  94.7600,  71.0700]])"
      ]
     },
     "metadata": {},
     "execution_count": 15
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "This output is all you need to get started. Now we can use the Dataloader Adapters!"
   ],
   "metadata": {
    "id": "vz97TRpZj451",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "from super_gradients.training.dataloaders.adapters import DetectionDataloaderAdapterFactory\n",
    "from data_gradients.dataset_adapters.config.data_config import DetectionDataConfig\n",
    "\n",
    "\n",
    "adapter_config = DetectionDataConfig(labels_extractor=coco_labels_extractor, cache_path=\"coco_adapter_cache.json\")\n",
    "train_loader = DetectionDataloaderAdapterFactory.from_dataset(\n",
    "    dataset=train_set,\n",
    "    config=adapter_config,\n",
    "    batch_size=5,\n",
    "    drop_last=True,\n",
    ")\n",
    "val_loader = DetectionDataloaderAdapterFactory.from_dataset(\n",
     "    dataset=val_set,\n",
    "    config=adapter_config,\n",
    "    batch_size=5,\n",
    "    drop_last=True,\n",
    ")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5cOJAT81ZfhO",
    "outputId": "578a4746-7bef-4189-d9e6-b92ccbec6b94",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 18,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "[2023-11-13 13:57:52] INFO - data_config.py - Cache deactivated for `DetectionDataConfig`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001B[33;1mHow many classes does your dataset include?\u001B[0m\n",
      "--------------------------------------------------------------------------------\n",
      "\n",
      "Enter your response >>> 80\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "[2023-11-13 13:57:55] INFO - detection_adapter_collate_fn.py - You are using Detection Adapter. Please note that it was designed specifically for YOLONAS, YOLOX and PPYOLOE.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Great! \u001B[33;1mYou chose: `80`\u001B[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001B[33;1mIn which format are your images loaded ?\u001B[0m\n",
      "--------------------------------------------------------------------------------\n",
      "\n",
      "\u001B[34;1mOptions\u001B[0m:\n",
      "[\u001B[34;1m0\u001B[0m] | RGB\n",
      "[\u001B[34;1m1\u001B[0m] | BGR\n",
      "[\u001B[34;1m2\u001B[0m] | LAB\n",
      "[\u001B[34;1m3\u001B[0m] | Other\n",
      "\n",
      "Your selection (Enter the \u001B[34;1mcorresponding number\u001B[0m) >>> 0\n",
      "Great! \u001B[33;1mYou chose: `RGB`\u001B[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001B[33;1m\u001B[33;1mWhich comes first\u001B[0m in your annotations, the class id or the bounding box?\u001B[0m\n",
      "--------------------------------------------------------------------------------\n",
      "Here's a sample of how your labels look like:\n",
      "Each line corresponds to a bounding box.\n",
      "tensor([[ 51.0000,   1.0800, 187.6900, 611.5900, 285.8400],\n",
      "        [ 51.0000, 311.7300,   4.3100, 319.2800, 228.6800],\n",
      "        [ 56.0000, 249.6000, 229.2700, 316.2400, 245.0800],\n",
      "        [ 51.0000,   0.0000,  13.5100, 434.4800, 375.1200]])\n",
      "\u001B[34;1mOptions\u001B[0m:\n",
      "[\u001B[34;1m0\u001B[0m] | Label comes first (e.g. [class_id, x1, y1, x2, y2])\n",
      "[\u001B[34;1m1\u001B[0m] | Bounding box comes first (e.g. [x1, y1, x2, y2, class_id])\n",
      "\n",
      "Your selection (Enter the \u001B[34;1mcorresponding number\u001B[0m) >>> 0\n",
      "Great! \u001B[33;1mYou chose: `Label comes first (e.g. [class_id, x1, y1, x2, y2])`\u001B[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001B[33;1mWhat is the \u001B[33;1mbounding box format\u001B[0m?\u001B[0m\n",
      "--------------------------------------------------------------------------------\n",
      "Here's a sample of how your labels look like:\n",
      "Each line corresponds to a bounding box.\n",
      "tensor([[ 51.0000,   1.0800, 187.6900, 611.5900, 285.8400],\n",
      "        [ 51.0000, 311.7300,   4.3100, 319.2800, 228.6800],\n",
      "        [ 56.0000, 249.6000, 229.2700, 316.2400, 245.0800],\n",
      "        [ 51.0000,   0.0000,  13.5100, 434.4800, 375.1200]])\n",
      "\u001B[34;1mOptions\u001B[0m:\n",
      "[\u001B[34;1m0\u001B[0m] | xyxy: x-left, y-top, x-right, y-bottom\t\t(Pascal-VOC format)\n",
      "[\u001B[34;1m1\u001B[0m] | xywh: x-left, y-top, width, height\t\t\t(COCO format)\n",
      "[\u001B[34;1m2\u001B[0m] | cxcywh: x-center, y-center, width, height\t\t(YOLO format)\n",
      "\n",
      "Your selection (Enter the \u001B[34;1mcorresponding number\u001B[0m) >>> 1\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "[2023-11-13 13:58:01] INFO - detection_adapter_collate_fn.py - You are using Detection Adapter. Please note that it was designed specifically for YOLONAS, YOLOX and PPYOLOE.\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Great! \u001B[33;1mYou chose: `xywh: x-left, y-top, width, height\t\t\t(COCO format)`\u001B[0m\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "_image, targets = next(iter(train_loader))\n",
    "targets.shape # [N, 6] format with 6 representing (sample_id, class_id, cx, cy, w, h)"
   ],
   "metadata": {
    "id": "m1S8n5YzgOKM",
    "pycharm": {
     "name": "#%%\n"
    },
    "outputId": "8749f639-78e3-4d20-cb5b-fa449ba9e0c8",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "execution_count": 19,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "torch.Size([14, 6])"
      ]
     },
     "metadata": {},
     "execution_count": 19
    }
   ]
  },
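  {
   "cell_type": "markdown",
   "source": [
    "To make the adapted format concrete, here is a quick sketch (not part of the original tutorial) that splits the `[N, 6]` targets back into their documented columns. The column layout `(sample_id, class_id, cx, cy, w, h)` is taken from the comment above; the tensor values below are purely illustrative.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Illustrative targets in the adapted [N, 6] layout: (sample_id, class_id, cx, cy, w, h)\n",
    "targets = torch.tensor([\n",
    "    [0.0, 51.0, 100.0, 200.0, 50.0, 80.0],\n",
    "    [1.0, 56.0, 250.0, 230.0, 60.0, 20.0],\n",
    "])\n",
    "\n",
    "sample_ids = targets[:, 0].long()   # index of the image within the batch\n",
    "class_ids = targets[:, 1].long()    # class index of each box\n",
    "boxes_cxcywh = targets[:, 2:]       # (cx, cy, w, h) coordinates\n",
    "\n",
    "print(sample_ids.shape, class_ids.shape, boxes_cxcywh.shape)\n",
    "```"
   ],
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },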
  {
   "cell_type": "markdown",
   "source": [
    "# III. Use your Adapted Dataloader to train a model"
   ],
   "metadata": {
    "id": "k6mRrkgF_zL7",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Now that we have an adapter for a detection dataset, let's use it to train YoloNAS!\n",
    "\n",
    "This is, of course, only for the sake of the example, since YoloNAS was originally trained on the SuperGradients implementation of the COCO dataset. You can replace `COCO` with any dataset of your own."
   ],
   "metadata": {
    "id": "_B5WQlBwgjdu",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "from omegaconf import OmegaConf\n",
    "from hydra.utils import instantiate\n",
    "\n",
    "from super_gradients import Trainer\n",
    "from super_gradients.training import models\n",
    "from super_gradients.common.object_names import Models\n",
    "from super_gradients.training import training_hyperparams\n",
    "from super_gradients.common.environment.cfg_utils import load_recipe\n",
    "\n",
    "\n",
    "trainer = Trainer(experiment_name=\"yolonas_training_with_adapter\", ckpt_root_dir=\"./\")\n",
    "model = models.get(model_name=Models.YOLO_NAS_S, num_classes=adapter_config.n_classes, pretrained_weights=\"coco\")\n",
    "\n",
    "yolonas_recipe = load_recipe(config_name=\"coco2017_yolo_nas_s\", overrides=[f\"arch_params.num_classes={adapter_config.n_classes}\", \"training_hyperparams.max_epochs=1\", \"training_hyperparams.mixed_precision=False\"])\n",
    "yolonas_recipe = OmegaConf.to_container(instantiate(yolonas_recipe))\n",
    "training_params = yolonas_recipe['training_hyperparams']\n",
    "\n",
    "trainer.train(model=model, training_params=training_params, train_loader=train_loader, valid_loader=val_loader)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EKvvqGvQC_k0",
    "outputId": "a9c90ec9-8ebb-492a-d2ca-31b9490ff7b3",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 19,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "[2023-11-13 08:29:34] INFO - checkpoint_utils.py - License Notification: YOLO-NAS pre-trained weights are subjected to the specific license terms and conditions detailed in \n",
      "https://github.com/Deci-AI/super-gradients/blob/master/LICENSE.YOLONAS.md\n",
      "By downloading the pre-trained weight files you agree to comply with these terms.\n",
      "Downloading: \"https://sghub.deci.ai/models/yolo_nas_s_coco.pth\" to /root/.cache/torch/hub/checkpoints/yolo_nas_s_coco.pth\n",
      "100%|██████████| 73.1M/73.1M [00:02<00:00, 27.5MB/s]\n",
      "[2023-11-13 08:29:37] INFO - checkpoint_utils.py - Successfully loaded pretrained weights for architecture yolo_nas_s\n",
      "[2023-11-13 08:29:38] INFO - sg_trainer.py - Starting a new run with `run_id=RUN_20231113_082938_239280`\n",
      "[2023-11-13 08:29:38] INFO - sg_trainer.py - Checkpoints directory: ./yolonas_training_with_adapter/RUN_20231113_082938_239280\n",
      "[2023-11-13 08:29:38] INFO - sg_trainer.py - Using EMA with params {'decay': 0.9997, 'decay_type': 'threshold', 'beta': 15}\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "The console stream is now moved to ./yolonas_training_with_adapter/RUN_20231113_082938_239280/console_Nov13_08_29_38.txt\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "[2023-11-13 08:29:38] WARNING - callbacks.py - Number of warmup steps (1000) is greater than number of steps in epoch (100). Warmup steps will be capped to number of steps in epoch to avoid interfering with any pre-epoch LR schedulers.\n",
      "[2023-11-13 08:29:38] INFO - sg_trainer_utils.py - TRAINING PARAMETERS:\n",
      "    - Mode:                         Single GPU\n",
      "    - Number of GPUs:               0          (0 available on the machine)\n",
      "    - Full dataset size:            500        (len(train_set))\n",
      "    - Batch size per GPU:           5          (batch_size)\n",
      "    - Batch Accumulate:             1          (batch_accumulate)\n",
      "    - Total batch size:             5          (num_gpus * batch_size)\n",
      "    - Effective Batch size:         5          (num_gpus * batch_size * batch_accumulate)\n",
      "    - Iterations per epoch:         100        (len(train_loader))\n",
      "    - Gradient updates per epoch:   100        (len(train_loader) / batch_accumulate)\n",
      "\n",
      "[2023-11-13 08:29:38] INFO - sg_trainer.py - Started training for 1 epochs (0/0)\n",
      "\n",
      "Train epoch 0:   1%|          | 1/100 [00:16<26:48, 16.25s/it, PPYoloELoss/loss=5.34, PPYoloELoss/loss_cls=2.84, PPYoloELoss/loss_dfl=1.29, PPYoloELoss/loss_iou=1.21, gpu_mem=0]"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# IV. Dig deeper into the Adapter\n",
    "\n",
    "By default, any parameter that cannot be confidently inferred will trigger a question.\n",
    "\n",
    "However, you can set these parameters in advance through the config object. In the previous example we had to set `labels_extractor` explicitly; now let's set all the parameters."
   ],
   "metadata": {
    "id": "JWUsWehVmowy",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "from super_gradients.training.dataloaders.adapters import DetectionDataloaderAdapterFactory\n",
    "from data_gradients.dataset_adapters.config.data_config import DetectionDataConfig\n",
    "from data_gradients.utils.data_classes.image_channels import ImageChannels\n",
    "class_names = [category['name'] for category in train_set.coco.loadCats(train_set.coco.getCatIds())]\n",
    "\n",
    "adapter_config = DetectionDataConfig(\n",
    "    labels_extractor=coco_labels_extractor,\n",
    "    is_label_first=True,\n",
    "    class_names=class_names,\n",
    "    image_channels=ImageChannels.from_str(\"RGB\"),\n",
    "    xyxy_converter='xywh',\n",
    "    cache_path=\"coco_adapter_cache_with_default.json\"\n",
    ")"
   ],
   "metadata": {
    "id": "wsV8JpBJmyIH",
    "pycharm": {
     "name": "#%%\n"
    },
    "outputId": "3b963871-663a-4155-cfc3-7654fbd7f255",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "execution_count": 20,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "[2023-11-13 13:58:08] INFO - data_config.py - Cache deactivated for `DetectionDataConfig`.\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "This config can now be used without having to answer any questions."
   ],
   "metadata": {
    "id": "__YWkBVas6Yp",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "train_loader = DetectionDataloaderAdapterFactory.from_dataset(\n",
    "    dataset=train_set,\n",
    "    config=adapter_config,\n",
    "    batch_size=5,\n",
    "    drop_last=True,\n",
    ")\n",
    "val_loader = DetectionDataloaderAdapterFactory.from_dataset(\n",
    "    dataset=train_set,\n",
    "    config=adapter_config,\n",
    "    batch_size=5,\n",
    "    drop_last=True,\n",
    ")\n",
    "\n",
    "_image, targets = next(iter(train_loader))\n",
    "targets.shape # [N, 6] format with 6 representing (sample_id, class_id, cx, cy, w, h)"
   ],
   "metadata": {
    "id": "uZDZg14InB3U",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "7fc801c7-abac-4319-ad68-afef23e6df29",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 21,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "[2023-11-13 13:58:11] INFO - detection_adapter_collate_fn.py - You are using Detection Adapter. Please note that it was designed specifically for YOLONAS, YOLOX and PPYOLOE.\n",
      "[2023-11-13 13:58:11] INFO - detection_adapter_collate_fn.py - You are using Detection Adapter. Please note that it was designed specifically for YOLONAS, YOLOX and PPYOLOE.\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "torch.Size([14, 6])"
      ]
     },
     "metadata": {},
     "execution_count": 21
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Load from existing cache\n",
    "\n",
    "You can reuse the cache of an adapter from a previous run. This allows you to skip the questions that were already answered."
   ],
   "metadata": {
    "id": "otTI_sVxtxvf",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
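  {
   "cell_type": "markdown",
   "source": [
    "Before relying on a cache, you can sanity-check that the file from the previous run actually exists. A minimal sketch, assuming the `cache_path` used above (the cache is stored as a JSON file, so its top-level keys can be inspected with the standard library):\n",
    "\n",
    "```python\n",
    "import json\n",
    "import os\n",
    "\n",
    "cache_path = \"coco_adapter_cache_with_default.json\"\n",
    "\n",
    "if os.path.exists(cache_path):\n",
    "    # Print the top-level keys to see which answers were stored.\n",
    "    with open(cache_path) as f:\n",
    "        print(list(json.load(f).keys()))\n",
    "else:\n",
    "    print(f\"No cache found at {cache_path}; questions will be asked again.\")\n",
    "```"
   ],
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },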
  {
   "cell_type": "code",
   "source": [
    "# The new config will load the answers to the questions asked in the previous run.\n",
    "adapter_config = DetectionDataConfig(\n",
    "    labels_extractor=coco_labels_extractor,\n",
    "    cache_path=\"coco_adapter_cache_with_default.json\" # Name of the previous cache\n",
    ")\n",
    "\n",
    "train_loader = DetectionDataloaderAdapterFactory.from_dataset(\n",
    "    dataset=train_set,\n",
    "    config=adapter_config,\n",
    "    batch_size=5,\n",
    "    drop_last=True,\n",
    ")\n",
    "val_loader = DetectionDataloaderAdapterFactory.from_dataset(\n",
    "    dataset=train_set,\n",
    "    config=adapter_config,\n",
    "    batch_size=5,\n",
    "    drop_last=True,\n",
    ")\n",
    "\n",
    "_image, targets = next(iter(train_loader))"
   ],
   "metadata": {
    "id": "y3I00k-2svd3",
    "pycharm": {
     "name": "#%%\n"
    },
    "outputId": "3c5bd8b4-3781-4657-d0ef-a40e896c6a4a",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "execution_count": 22,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "[2023-11-13 13:58:14] INFO - detection_adapter_collate_fn.py - You are using Detection Adapter. Please note that it was designed specifically for YOLONAS, YOLOX and PPYOLOE.\n",
      "[2023-11-13 13:58:14] INFO - detection_adapter_collate_fn.py - You are using Detection Adapter. Please note that it was designed specifically for YOLONAS, YOLOX and PPYOLOE.\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "targets.shape # [N, 6] format with 6 representing (sample_id, class_id, cx, cy, w, h)"
   ],
   "metadata": {
    "id": "988EZEJpU3bf",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "d6bfd31b-5b86-4d3a-b2b9-94bb4f40b52f",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 23,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "torch.Size([14, 6])"
      ]
     },
     "metadata": {},
     "execution_count": 23
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "As you can see, no questions were asked, and the targets are still adapted into the SuperGradients format."
   ],
   "metadata": {
    "id": "1Fw4yhwkuK7w",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  }
 ]
}
